Add FlashInfer allreduce RMSNorm Quant fusion (#21069)#32
MitchLewis930 wants to merge 1 commit into ROCM_bug_before
Conversation
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
📝 Walkthrough

The pull request extends all-reduce fusion capabilities to support static FP8 and FP4 quantization paths combined with RMS normalization. It introduces new pattern classes for quantized fusion operations, adds test infrastructure for quantization variants, wires in new compilation passes, and adjusts token limits for fusion optimization.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Runner
    participant Pass as AllReduceFusionPass
    participant PatternMatcher as Pattern Matcher
    participant Backend as TestBackend
    participant Runtime as TRTLLM Runtime
    Test->>Pass: Initialize with vllm_config
    Pass->>Pass: Compute max_num_token from model/token config
    Pass->>Pass: Register AllReduceRMSNormPattern (eps1, eps2)
    Pass->>Pass: Register AllReduceFusedRMSNormStaticQuantFP8Pattern (eps1, eps2)
    Pass->>Pass: Register AllReduceFusedAddRMSNormStaticQuantFP8Pattern (eps1, eps2)
    alt Device supports NVFP4
        Pass->>Pass: Register NVFP4 pattern variants (eps1, eps2)
    end
    Pass->>PatternMatcher: Clear inductor cache after each epsilon
    Test->>Backend: Compile model with fusion passes
    Backend->>PatternMatcher: Match patterns in graph
    PatternMatcher->>Backend: Return matched fusion opportunities
    Backend->>Runtime: Execute fused all-reduce+norm+quant operations
    Runtime->>Test: Return fused computation results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/compile/test_fusion_all_reduce.py (1)
29-34: Silence unused `token_num` args to keep Ruff clean.

Ruff ARG002 flags these constructor parameters as unused. Either store them or rename to `_token_num` to keep lint green.

✅ Suggested fix

```diff
 class TestAllReduceRMSNormModel(torch.nn.Module):
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
+        self.token_num = token_num
@@
 class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
     def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.norm = RMSNorm(hidden_size, eps)
+        self.token_num = token_num
```

Also applies to: 50-55
🤖 Fix all issues with AI agents
In `@tests/compile/test_fusion_all_reduce.py`:
- Around line 112-116: The inline lambda named round_up should be replaced with
the shared helper to avoid defining lambdas in expressions; locate the lambda
used to compute rounded_m/rounded_n (the round_up = lambda x, y: ... and its
uses before self.output_scale = torch.empty(...)) and replace it with a call to
vllm.utils.round_up (or define a local def round_up(x, y): ...) so the same
rounding logic is reused and the lambda assignment is removed, then update the
calls computing rounded_m and rounded_n to use the helper.
In `@vllm/compilation/collective_fusion.py`:
- Around line 499-503: Fix the typo in the comment ("outpput" -> "output") and
avoid the subtle logic bug by tracking whether norm_out was originally provided
before it gets reassigned: capture a boolean like original_norm_out_provided =
(norm_out is not None) immediately before any internal assignment to norm_out
(around where norm_out is assigned at/near the code referencing line 488), then
replace the condition `if scale_factor is None or norm_out is not None:` with
`if scale_factor is None or original_norm_out_provided:` so the copy_
(allreduce_in.copy_(allreduce_out)) only runs when scale_factor is None or
norm_out was passed in by the caller; update the comment to read "output" and
briefly explain the check uses the original-provided flag.
🧹 Nitpick comments (3)
vllm/compilation/collective_fusion.py (3)
704-780: Consider consistency in `get_inputs` placement.

The `get_inputs()` function is defined inside `register()` rather than as a class method like in `AllReduceRMSNormPattern` and other base patterns. While functional, this is inconsistent with the established pattern style in this file. The pattern logic and return indices are correct.

1084-1088: Add bounds checking for `max_num_token`.

If `hidden_dim * tp_size` is very large, the integer division could result in 0, which would likely cause issues downstream. Consider adding a minimum bound or guard.

Proposed fix

```diff
         max_num_token = min(
             _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
             (self.hidden_dim * self.tp_size *
              (4 if use_fp32_lamport else 2)),
             config.compilation_config.pass_config.
             fi_allreduce_fusion_max_token_num)
+        if max_num_token <= 0:
+            logger.warning(
+                "Computed max_num_token is %d, skipping allreduce fusion pass",
+                max_num_token)
+            return
```

1149-1151: Add a version guard and document PyTorch compatibility for the `_seen_patterns` cache hack.

`torch._inductor.pattern_matcher._seen_patterns` is a private internal API not covered by PyTorch's stability guarantees and may change across versions. While the repository pins `torch == 2.7.1`, consider wrapping this in a version check (e.g., `if torch.__version__.startswith("2.7")`) and documenting the minimum PyTorch version tested for this specific pattern, consistent with the version guard pattern used elsewhere in `vllm/compilation/compiler_interface.py`.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- .buildkite/test-pipeline.yaml
- tests/compile/test_fusion_all_reduce.py
- tests/utils.py
- vllm/compilation/collective_fusion.py
- vllm/config.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/compile/test_fusion_all_reduce.py (9)
vllm/compilation/fix_functionalization.py (1)
- FixFunctionalizationPass (19-191)
vllm/compilation/noop_elimination.py (1)
- NoOpEliminationPass (18-165)
vllm/distributed/communication_op.py (1)
- tensor_model_parallel_all_reduce (12-14)
vllm/model_executor/layers/layernorm.py (1)
- RMSNorm (89-200)
vllm/model_executor/layers/quantization/utils/quant_utils.py (1)
- GroupShape (24-32)
vllm/model_executor/layers/quantization/input_quant_fp8.py (1)
- QuantFP8 (24-103)
tests/utils.py (2)
- has_module_attribute (980-988)
- multi_gpu_test (893-906)
vllm/utils/__init__.py (1)
- round_up (981-982)
vllm/model_executor/layers/quantization/rtn.py (1)
- shape (104-111)
vllm/compilation/collective_fusion.py (2)
vllm/distributed/communication_op.py (1)
- tensor_model_parallel_all_reduce (12-14)
vllm/distributed/parallel_state.py (2)
- all_reduce (105-110)
- all_reduce (341-364)
🪛 Ruff (0.14.13)
tests/compile/test_fusion_all_reduce.py
29-29: Unused method argument: token_num
(ARG002)
50-50: Unused method argument: token_num
(ARG002)
112-112: Do not assign a lambda expression, use a def
Rewrite round_up as a def
(E731)
🔇 Additional comments (11)
.buildkite/test-pipeline.yaml (1)
356-356: LGTM!

The new test file `compile/test_fusion_all_reduce.py` is appropriately added to the PyTorch Compilation Unit Tests section, which aligns with the fusion pass functionality being tested.

vllm/compilation/collective_fusion.py (10)

40-41: LGTM!

The new quantization operation constants follow the established naming convention and are appropriately defined at module scope for use in the pattern classes.

524-536: LGTM!

The custom op registration correctly declares all mutated arguments, including the new `quant_out` and `scale_out` parameters.

541-571: LGTM!

The `FlashInferFusedAllReduceParams` class is cleanly extended with the new `fuse_rms_quant` parameter while maintaining backward compatibility through the default value.

574-636: LGTM!

The class naming is now consistent (PascalCase for "Norm"), and the pattern replacement correctly uses the fused operation with proper return tuple indices matching the `mutates_args` declaration.

639-701: LGTM!

The pattern correctly handles the fused add RMS norm case, where `norm_out=None` indicates an in-place operation, and the return indices properly extract `allreduce_in` and `residual` from the result tuple.

783-868: LGTM!

The FP8 quantization pattern with residual addition is correctly implemented. The return indices properly extract `quant_out` and the updated `residual`.

871-960: LGTM!

The NVFP4 quantization pattern correctly handles the three-output case (quantized output, allreduce output, and output scale) with proper return indices.

963-1052: LGTM!

The NVFP4 pattern with residual addition is correctly implemented, with appropriate return values for the three outputs.

1123-1135: LGTM on device capability gating.

The NVFP4 patterns are appropriately gated behind `has_device_capability(100)` for Blackwell GPUs.

1165-1169: LGTM!

The cleanup logic properly destroys IPC workspace resources when the pass is deleted.
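A generic sketch of the cleanup pattern being approved here. All names are illustrative stand-ins; the real pass destroys a FlashInfer IPC workspace rather than a string token:

```python
class FusionPassWithWorkspace:
    """Holds an externally allocated workspace and releases it on deletion."""

    def __init__(self, allocator, destroyer):
        self._destroy = destroyer
        self.workspace = allocator()

    def __del__(self):
        # Best-effort release; guarded with getattr so a partially
        # initialized instance does not raise during interpreter shutdown.
        workspace = getattr(self, "workspace", None)
        if workspace is not None:
            self._destroy(workspace)
            self.workspace = None


released = []
p = FusionPassWithWorkspace(lambda: "ipc_workspace", released.append)
del p  # CPython refcounting invokes __del__ immediately here
assert released == ["ipc_workspace"]
```

Relying on `__del__` is fragile on non-refcounted interpreters, so production code often pairs it with an explicit `close()` or atexit hook; the sketch only mirrors the destructor-based cleanup the review comment refers to.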
```python
        round_up = lambda x, y: (x + y - 1) // y * y
        rounded_m = round_up(token_num, 128)
        scale_n = hidden_size // 16
        rounded_n = round_up(scale_n, 4)
        self.output_scale = torch.empty((rounded_m, rounded_n // 4),
```
Replace the inline round_up lambda (Ruff E731).
Use the existing vllm.utils.round_up helper (or a local def) instead of assigning a lambda.
✅ Suggested fix

```diff
-from vllm.utils import update_environment_variables
+from vllm.utils import round_up, update_environment_variables
@@
-        round_up = lambda x, y: (x + y - 1) // y * y
         rounded_m = round_up(token_num, 128)
         scale_n = hidden_size // 16
         rounded_n = round_up(scale_n, 4)
```
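For reference, the ceiling-rounding behavior both the lambda and the shared helper provide. This is a minimal standalone sketch; `vllm.utils.round_up` is assumed to behave equivalently:

```python
def round_up(x: int, y: int) -> int:
    # Round x up to the nearest multiple of y.
    return (x + y - 1) // y * y


# With the test model defaults (token_num=16, hidden_size=16):
rounded_m = round_up(16, 128)     # token rows padded to a 128 multiple
scale_n = 16 // 16                # one FP4 scale group per 16 elements
rounded_n = round_up(scale_n, 4)  # scale columns padded to a 4 multiple
assert (rounded_m, rounded_n) == (128, 4)
```

Using the shared helper instead of a local lambda keeps the rounding logic in one place and satisfies Ruff E731.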
```python
        if scale_factor is None or norm_out is not None:
            # we need to return allreduce outpput
            # in cases of non quant fused AR + RMS norm
            # and fused AR + RMS norm + quant without fused add
            allreduce_in.copy_(allreduce_out)
```
Typo in comment and potential logic clarity issue.
Line 500 has a typo: "outpput" should be "output".
Also, the condition `norm_out is not None` on line 499 will now always be true in the else branch, because `norm_out` is assigned at line 488 when it was originally None. This means the `copy_` always executes in the non-fused path when `scale_factor` is None. Consider clarifying with a local variable that tracks whether `norm_out` was originally provided vs. assigned internally.
Proposed fix for typo

```diff
         if scale_factor is None or norm_out is not None:
-            # we need to return allreduce outpput
+            # we need to return allreduce output
```
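A minimal control-flow sketch of the suggested tracking-flag fix. Tensor operations are replaced with plain-Python stand-ins and the function name is illustrative; it only demonstrates the decision logic, not the real kernel:

```python
def needs_allreduce_copy(scale_factor, norm_out) -> bool:
    """Decide whether allreduce_in.copy_(allreduce_out) should run.

    The flag is captured *before* norm_out is reassigned internally,
    so an internally allocated norm_out no longer triggers the copy.
    """
    norm_out_provided = norm_out is not None
    if norm_out is None:
        norm_out = object()  # stand-in for the internal allocation
    return scale_factor is None or norm_out_provided


# Non-quant fused AR + RMS norm path: copy always needed.
assert needs_allreduce_copy(None, None) is True
# Quant path where norm_out was allocated internally: copy skipped.
assert needs_allreduce_copy(1.0, None) is False
# Quant path where the caller supplied norm_out: copy still needed.
assert needs_allreduce_copy(1.0, "caller_buffer") is True
```

With the unfixed condition, the second case would incorrectly return True because the internal assignment makes `norm_out is not None` unconditionally true.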
test